import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from models.sdtv3 import sdtv3_s, sdtv3_s_attn, sdtv3_s_channelmlp, sdtv3_s_fullattn, sdtv3_s_splash
from models.vit import VisionTransformer, VisionTransformer_attn, test_vit_attention, vit_tiny_patch16_224
from models.metaformer import poolformerv2_s12, caformer_s18, convformer_s18
from models.qkformer import QKFormer_10_384
from models.MAE_SDT import spikmae_12_512
from models.sdtv3_large import spikformer12_512
from models.sd_former_v1 import sdt
from models.spikformer import vit_snn
from functions.erf import compute_erf, compute_erf_sdt, compute_erf_pool, compute_erf_qk, compute_erf_sdtv1, compute_erf_spikformerv1
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers import AutoModel

from spikingjelly.clock_driven import functional


def visualize_erf(erf_map, title="Effective Receptive Field", file_name="erf.pdf"):
    """Render a single ERF map as a grayscale heatmap and save it to disk.

    Args:
        erf_map: Array-like accepted by ``plt.imshow``; presumably a 2-D map
            of gradient magnitudes (per the colorbar label) — confirm against
            the ``compute_erf*`` helper that produced it.
        title: Figure title.
        file_name: Output path; matplotlib infers the format from the
            extension (PDF by default).

    The figure is always closed after saving, so calling this repeatedly
    (e.g. once per model) does not accumulate open matplotlib figures.
    """
    fig = plt.figure(figsize=(10, 8))
    try:
        # 'gray' is matplotlib's canonical colormap name; the 'grey' alias is
        # not registered in older matplotlib releases.
        plt.imshow(erf_map, cmap='gray')
        plt.colorbar(label='Gradient Magnitude')
        plt.title(title)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(file_name, dpi=300)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)

def _save_erf_figure(erf_map, title, colorbar_label, file_name):
    """Draw one grayscale ERF heatmap with a colorbar, save it, and close the figure."""
    fig = plt.figure(figsize=(10, 8))
    try:
        # 'gray' is the canonical colormap name ('grey' is only an alias in
        # newer matplotlib releases).
        plt.imshow(erf_map, cmap='gray')
        plt.colorbar(label=colorbar_label)
        plt.title(title)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(file_name, dpi=300)
    finally:
        plt.close(fig)


def visualize_multierf(erf_maps, base_filename="erf"):
    """Save per-layer, grid, and layer-averaged ERF heatmaps as PDF files.

    For a mapping ``{layer_name: erf_map}`` this writes:
      * ``{base_filename}_{layer_name}.pdf`` - one figure per layer;
      * ``{base_filename}_all_layers.pdf`` - every layer in a grid of at most
        two columns;
      * ``{base_filename}_average.pdf`` - the element-wise mean over layers.

    Args:
        erf_maps: Mapping from layer name to an array-like accepted by
            ``plt.imshow``. The average is computed with
            ``np.mean(..., axis=0)`` over all maps, so they must share one shape.
        base_filename: Path prefix for every output file. ``.pdf`` is appended
            automatically, so pass a bare stem (no extension).

    Raises:
        ValueError: If ``erf_maps`` is empty (nothing to plot or average).
    """
    if not erf_maps:
        raise ValueError("visualize_multierf: erf_maps is empty; nothing to plot")

    # 1) One standalone figure per layer.
    for layer_name, erf_map in erf_maps.items():
        _save_erf_figure(
            erf_map,
            title=f"Effective Receptive Field - {layer_name}",
            colorbar_label='Gradient Magnitude',
            file_name=f"{base_filename}_{layer_name}.pdf",
        )

    # 2) All layers side by side: at most two columns, as many rows as needed.
    n_layers = len(erf_maps)
    rows = int(np.ceil(n_layers / 2))
    cols = min(n_layers, 2)

    fig = plt.figure(figsize=(cols * 5, rows * 4))
    try:
        for i, (layer_name, erf_map) in enumerate(erf_maps.items()):
            plt.subplot(rows, cols, i + 1)
            plt.imshow(erf_map, cmap='gray')
            plt.colorbar(label='Magnitude')
            plt.title(layer_name)
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(f"{base_filename}_all_layers.pdf", dpi=300)
    finally:
        plt.close(fig)

    # 3) Element-wise mean across layers (requires identically shaped maps).
    avg_erf = np.mean(list(erf_maps.values()), axis=0)
    _save_erf_figure(
        avg_erf,
        title="Average Effective Receptive Field Across All Layers",
        colorbar_label='Average Gradient Magnitude',
        file_name=f"{base_filename}_average.pdf",
    )

# ############## ERF of ViT-Base-16-224 ###############
# if __name__ == "__main__":
#     import os
#     os.environ["CUDA_VISIBLE_DEVICES"] = '1'

#     processor = ViTImageProcessor.from_pretrained('./vit-tiny-16-224')
#     model = ViTForImageClassification.from_pretrained('./vit-tiny-16-224')

#     ## without pretrained
#     model = vit_tiny_patch16_224(pretrained="/data2/users/zhangjy/erf_sdt/vit_model.pth")
    
#     # model = VisionTransformer_attn(
#     #     img_size = 224,
#     #     patch_size = 16,
#     #     in_chans = 3,
#     #     num_classes = 1000,
#     #     global_pool = 'token',
#     #     embed_dim = 192,
#     #     depth=12, 
#     #     num_heads=3,
#     # ) 
#     model.eval()
#     for name, param in model.named_parameters():
#         print(name, param.size())

#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model = model.to(device)

#     single_erf = compute_erf(model, image_size=(224, 224), num_runs=20)
#     # visualize_erf(single_erf, title="ERF of ViT-Base-16-224", file_name="erf_vit_base_16_224_w_pretrained.pdf")
#     visualize_multierf(single_erf, base_filename="erf_vit_base_16_224_w_pretrained")

# ############### ERF of SDTV1-8-384 ###############
# if __name__ == "__main__":
    
#     model = sdt(
#         img_size_h=224,
#         img_size_w=224,
#         patch_size=16,
#         in_channels=3,
#         embed_dims=384,
#         num_heads=8,
#         mlp_ratios=4,
#         qkv_bias=True,
#         drop_rate=0.0,
#         attn_drop_rate=0.0,
#         drop_path_rate=0.0,
#         depths=8,
#         sr_ratios=[8, 4, 2],
#         T=4,
#         pooling_stat="1111",
#         attn_mode="direct_xor",
#         spike_mode="lif",
#     ) 
#     model.eval()
#     for name, param in model.named_parameters():
#         print(name, param.size())


#     state_dict = torch.load('sd_former_v1_8_384.pth.tar', map_location='cpu')
#     model.load_state_dict(state_dict["state_dict"], strict=False)

#     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#     model = model.to(device)
#     functional.reset_net(model)

#     single_erf = compute_erf_sdtv1(model, image_size=(224, 224), num_runs=90)
#     # visualize_erf(single_erf, title="ERF of SDTV1-8-384", file_name="erf_sdtv1_s_w_pretrained.pdf")
#     visualize_multierf(single_erf, base_filename="SDTV1-8-384")

# ############### ERF of SpikformerV1-8-512 ###############
# if __name__ == "__main__":
    
#     model = vit_snn() 
#     model.eval()
#     for name, param in model.named_parameters():
#         print(name, param.size())

#     state_dict = torch.load('spikformer_checkpoint-308.pth.tar', map_location='cpu')
#     model.load_state_dict(state_dict["state_dict"])

#     device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')
#     model = model.to(device)
#     functional.reset_net(model)

#     single_erf = compute_erf_spikformerv1(model, image_size=(224, 224), num_runs=300)
#     # visualize_erf(single_erf, title="ERF of SDTV1-8-384", file_name="erf_spikformerv1_s_w_pretrained.pdf")
#     visualize_multierf(single_erf, base_filename="erf_spikformerv1_s_w_pretrained")

# ############### ERF of SDTV3-S-16-224 ###############
# if __name__ == "__main__":
    
#     model = sdtv3_s()
#     model.eval()
#     for name, param in model.named_parameters():
#         print(name, param.size())

#     state_dict = torch.load('V3_5.1M_1x4.pth', map_location='cpu')
#     model.load_state_dict(state_dict['model'], strict=False)

#     device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')
#     model = model.to(device)

#     single_erf = compute_erf_sdt(model, image_size=(224, 224), num_runs=60)
#     visualize_multierf(single_erf, base_filename="erf_sdtv3_s_w_pretrained.pdf")

############### ERF of SDTV3-S-splash ###############
if __name__ == "__main__":
    # Build the SDTV3-S "splash" variant, load its pretrained checkpoint, and
    # save per-layer ERF plots.
    model = sdtv3_s_splash()
    model.eval()
    # Print parameter names/shapes so checkpoint mismatches are easy to spot.
    for name, param in model.named_parameters():
        print(name, param.size())

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load('sdtv3_s_splash_checkpoint-198.pth', map_location='cpu')
    # Strict load (default): any key mismatch with the checkpoint raises.
    model.load_state_dict(state_dict['model'])

    # device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')
    device = torch.device('cpu')
    model = model.to(device)

    single_erf = compute_erf_sdt(model, image_size=(224, 224), num_runs=60)
    # visualize_multierf appends "_<layer>.pdf" itself, so pass a bare stem;
    # a ".pdf" suffix here would produce names like "x.pdf_layer.pdf".
    visualize_multierf(single_erf, base_filename="erf_sdtv3_s_splash_w_pretrained")


# ############### ERF of SDTV3-MAE-12-512 ###############
# if __name__ == "__main__":
    
#     model = spikformer12_512()
#     model.eval()
#     for name, param in model.named_parameters():
#         print(name, param.size())

#     # state_dict = torch.load('V3_5.1M_1x4.pth', map_location='cpu')
#     # model.load_state_dict(state_dict['model'], strict=False)

#     device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')
#     model = model.to(device)

#     single_erf = compute_erf_sdt(model, image_size=(224, 224), num_runs=500)
#     visualize_erf(single_erf, title="ERF of SDTV3-S-16-224", file_name="erf_sdtv3_s.pdf")

# ############### ERF of ANN Metaformer ###############
# if __name__ == "__main__":
    
#     model = poolformerv2_s12()
#     model.eval()

#     for name, param in model.named_parameters():
#         print(name, param.size())

#     state_dict = torch.load('poolformerv2_s12.pth', map_location='cpu')
#     model.load_state_dict(state_dict)

#     device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')
#     model = model.to(device)


#     single_erf = compute_erf_pool(model, image_size=(224, 224), num_runs=60)
#     visualize_multierf(single_erf, base_filename="erf_poolformerv2_s12_w_pretrained")

# ############### ERF of QKFormer ###############
# if __name__ == "__main__":
    
#     model = QKFormer_10_384(T = 4)
#     model.eval()

#     for name, param in model.named_parameters():
#         print(name, param.size())

#     state_dict = torch.load('qk_10_384_checkpoint-199.pth', map_location='cpu')
#     model.load_state_dict(state_dict["model"])
    

#     device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')
#     model = model.to(device)

#     # input_path = "/data/dataset/ImageNet/val/n01440764/ILSVRC2012_val_00000293.JPEG"
#     single_erf = compute_erf_qk(model, image_size=(224, 224), num_runs=50, input_path=None)
#     visualize_multierf(single_erf, base_filename="erf_qkformer_10_384_w_pretrained.pdf")